package templates

import (
	"bytes"
	"errors"
	"fmt"
	"go/types"
	"io/fs"
	"os"
	"path/filepath"
	"reflect"
	"regexp"
	"runtime"
	"sort"
	"strconv"
	"strings"
	"sync"
	"text/template"
	"unicode"

	"github.com/99designs/gqlgen/internal/code"
	"github.com/99designs/gqlgen/internal/imports"
)

// CurrentImports keeps track of all the import declarations that are needed during the execution of
// a plugin. This is done with a global because subtemplates currently get called in functions. Let's
// aim to remove this eventually.
//
// Render sets it for the duration of a render and panics if it is already set when called.
var CurrentImports *Imports

// Options specify various parameters to rendering a template.
type Options struct {
	// PackageName is a helper that specifies the package header declaration.
	// In other words, when you write the template you don't need to specify `package X`
	// at the top of the file. By providing PackageName in the Options, the Render
	// function will do that for you.
	PackageName string
	// Template is a string of the entire template that
	// will be parsed and rendered. If it's empty,
	// the plugin processor will look for .gotpl files
	// in the same directory of where you wrote the plugin.
	Template string

	// TemplateFS is an alternative to the Template option: use the go:embed API to
	// collect all the template files you want to pass into Render. It is only
	// consulted when Template is empty.
	TemplateFS fs.FS

	// Filename is the name of the file that will be
	// written to the system disk once the template is rendered.
	//
	// RegionTags, when true, makes Render surround the output of each root template
	// with "// region" and "// endregion" marker comments.
	//
	// GeneratedHeader, when true, makes Render start the file with a
	// "// Code generated by github.com/99designs/gqlgen, DO NOT EDIT." line.
	Filename        string
	RegionTags      bool
	GeneratedHeader bool
	// PackageDoc is documentation written above the package line
	PackageDoc string
	// FileNotice is notice written below the package line
	FileNotice string
	// Data will be passed to the template execution.
	//
	// Funcs are extra template functions; Render adds them on top of the map
	// returned by Funcs(), so an entry with the same name replaces the default.
	Data  any
	Funcs template.FuncMap

	// Packages cache, you can find me on config.Config
	Packages *code.Packages
}

// Package-level state shared by the naming helpers:
//   - modelNamesMu guards modelNames (locked by goModelName and resetModelNames).
//   - modelNames maps a key built by buildGoModelNameKey to the Go name that
//     goModelName already handed out for it.
//   - goNameRe matches every character other than ASCII letters, digits and '_';
//     replaceInvalidCharacters substitutes such characters.
var (
	modelNamesMu sync.Mutex
	modelNames   = make(map[string]string, 0)
	goNameRe     = regexp.MustCompile("[^a-zA-Z0-9_]")
)

// Render renders a gql plugin template from the given Options. Render is an
// abstraction of the text/template package that makes it easier to write gqlgen
// plugins. If Options.Template is empty, the Render function will look for `.gotpl`
// files inside the directory where you wrote the plugin.
//
// The rendered file is assembled as: optional generated-code header, optional
// package doc, the package clause, optional file notice, the import block
// collected in CurrentImports while the templates executed, and finally the
// output of every root template. The result is handed to write, and the
// package containing cfg.Filename is evicted from cfg.Packages afterwards.
//
// Render panics if CurrentImports is already set, i.e. when it is called
// recursively or concurrently. It returns an error when the templates cannot
// be parsed or executed, or when the file cannot be written.
func Render(cfg Options) error {
	if CurrentImports != nil {
		panic(errors.New("recursive or concurrent call to RenderToFile detected"))
	}
	CurrentImports = &Imports{packages: cfg.Packages, destDir: filepath.Dir(cfg.Filename)}
	// Release the global on every exit path. Without this, a failed parse or
	// template execution would leave CurrentImports set and make every later
	// Render call panic with a misleading "recursive or concurrent" message.
	defer func() { CurrentImports = nil }()

	funcs := Funcs()
	for n, f := range cfg.Funcs {
		funcs[n] = f
	}

	t := template.New("").Funcs(funcs)
	// NOTE: parseTemplates inspects the call stack with runtime.Caller(2) to find
	// the plugin's directory, so it must keep being called directly from Render.
	t, err := parseTemplates(cfg, t)
	if err != nil {
		return err
	}

	roots := make([]string, 0, len(t.Templates()))
	for _, templ := range t.Templates() {
		// templates that end with _.gotpl are special files we don't want to include
		if strings.HasSuffix(templ.Name(), "_.gotpl") ||
			// filter out templates added with {{ template xxx }} syntax inside the template file
			!strings.HasSuffix(templ.Name(), ".gotpl") {
			continue
		}

		roots = append(roots, templ.Name())
	}

	// then execute all the important looking ones in order, adding them to the same file
	sort.SliceStable(roots, func(i, j int) bool {
		// important files go first
		if strings.HasSuffix(roots[i], "!.gotpl") &&
			!strings.HasSuffix(roots[j], "!.gotpl") {
			return true
		}
		if strings.HasSuffix(roots[j], "!.gotpl") &&
			!strings.HasSuffix(roots[i], "!.gotpl") {
			return false
		}
		// files that have identical names are sorted dependent on input order
		// so we rely on SliceStable here to ensure deterministic results
		// to avoid test failures
		return roots[i] < roots[j]
	})

	var buf bytes.Buffer
	for _, root := range roots {
		if cfg.RegionTags {
			buf.WriteString("\n// region    " + center(70, "*", " "+root+" ") + "\n")
		}
		err := t.Lookup(root).Execute(&buf, cfg.Data)
		if err != nil {
			return fmt.Errorf("%s: %w", root, err)
		}
		if cfg.RegionTags {
			buf.WriteString("\n// endregion " + center(70, "*", " "+root+" ") + "\n")
		}
	}

	var result bytes.Buffer
	if cfg.GeneratedHeader {
		result.WriteString("// Code generated by github.com/99designs/gqlgen, DO NOT EDIT.\n\n")
	}
	if cfg.PackageDoc != "" {
		result.WriteString(cfg.PackageDoc + "\n")
	}
	result.WriteString("package ")
	result.WriteString(cfg.PackageName)
	result.WriteString("\n\n")
	if cfg.FileNotice != "" {
		result.WriteString(cfg.FileNotice)
		result.WriteString("\n\n")
	}
	// The import block can only be written now: CurrentImports was populated
	// as a side effect of executing the templates above.
	result.WriteString("import (\n")
	result.WriteString(CurrentImports.String())
	result.WriteString(")\n")
	_, err = buf.WriteTo(&result)
	if err != nil {
		return err
	}
	// Rendering is complete; release the global before touching the disk, as
	// before. The deferred reset above covers the early-return paths.
	CurrentImports = nil

	err = write(cfg.Filename, result.Bytes(), cfg.Packages)
	if err != nil {
		return err
	}

	cfg.Packages.Evict(code.ImportPathForDir(filepath.Dir(cfg.Filename)))
	return nil
}

// parseTemplates loads the templates described by cfg into t and returns the
// resulting template set.
//
// Sources are tried in this order:
//  1. cfg.Template, parsed under the name "template.gotpl";
//  2. every "*.gotpl" file in cfg.TemplateFS;
//  3. every "*.gotpl" file in the directory of the source file found two
//     frames up the call stack (runtime.Caller(2)), so the call depth between
//     that file and this function must not change.
func parseTemplates(cfg Options, t *template.Template) (*template.Template, error) {
	if cfg.Template != "" {
		parsed, err := t.New("template.gotpl").Parse(cfg.Template)
		if err != nil {
			return nil, fmt.Errorf("error with provided template: %w", err)
		}
		return parsed, nil
	}

	var source fs.FS = cfg.TemplateFS
	if source == nil {
		// load path relative to calling source file
		_, callerFile, _, _ := runtime.Caller(2)
		source = os.DirFS(filepath.Dir(callerFile))
	}

	parsed, err := t.ParseFS(source, "*.gotpl")
	if err != nil {
		return nil, fmt.Errorf("locating templates: %w", err)
	}
	return parsed, nil
}

// center surrounds s with repetitions of pad so that s sits in the middle of
// width pad units (sizes are measured in bytes of s). When the padding does
// not split evenly, the extra unit goes on the right. If fewer than two units
// of padding would fit, s is returned unchanged.
func center(width int, pad, s string) string {
	free := width - len(s)
	if free < 2 {
		return s
	}
	left := free / 2
	right := free - left
	return strings.Repeat(pad, left) + s + strings.Repeat(pad, right)
}

// Funcs returns a fresh map of the helper functions made available to
// templates. A new map is built on every call, so callers may add to or
// overwrite entries in the result.
//
// "reserveImport" and "lookupImport" are method values taken from
// CurrentImports at the moment Funcs is called.
func Funcs() template.FuncMap {
	funcs := template.FuncMap{
		// string helpers
		"ucFirst":     UcFirst,
		"lcFirst":     LcFirst,
		"quote":       strconv.Quote,
		"rawQuote":    rawQuote,
		"prefixLines": prefixLines,
		"strSplit":    StrSplit,

		// value helpers
		"dump":   Dump,
		"dict":   dict,
		"notNil": notNil,

		// type and import helpers
		"ref":           ref,
		"obj":           obj,
		"ts":            TypeIdentifier,
		"call":          Call,
		"reserveImport": CurrentImports.Reserve,
		"lookupImport":  CurrentImports.Lookup,

		// Go naming helpers
		"go":                 ToGo,
		"goPrivate":          ToGoPrivate,
		"goModelName":        ToGoModelName,
		"goPrivateModelName": ToGoPrivateModelName,
	}

	funcs["add"] = func(a, b int) int {
		return a + b
	}
	// "render" executes another template file; the name is resolved by
	// resolveName with a skip of 0.
	funcs["render"] = func(filename string, tpldata any) (*bytes.Buffer, error) {
		return render(resolveName(filename, 0), tpldata)
	}

	return funcs
}

// UcFirst returns s with its first rune converted to upper case. The empty
// string is returned unchanged.
func UcFirst(s string) string {
	runes := []rune(s)
	if len(runes) == 0 {
		return ""
	}
	runes[0] = unicode.ToUpper(runes[0])
	return string(runes)
}

// LcFirst returns s with its first rune converted to lower case. The empty
// string is returned unchanged.
func LcFirst(s string) string {
	runes := []rune(s)
	if len(runes) == 0 {
		return ""
	}
	runes[0] = unicode.ToLower(runes[0])
	return string(runes)
}

// isDelimiter reports whether c separates words in a name: a hyphen, an
// underscore, or any Unicode whitespace character.
func isDelimiter(c rune) bool {
	switch c {
	case '-', '_':
		return true
	default:
		return unicode.IsSpace(c)
	}
}

// ref returns the type string that CurrentImports.LookupType produces for p,
// with every occurrence of "interface{}" spelled as "any".
func ref(p types.Type) string {
	// TODO(steve): figure out why this is needed
	// otherwise inconsistent sometimes
	// see https://github.com/99designs/gqlgen/issues/3414#issuecomment-2822856422
	//
	// A single replacement covers the bare "interface{}", "map[string]interface{}"
	// and any other container of interface{} such as "[]interface{}"; strings
	// that do not mention interface{} come back untouched.
	return strings.ReplaceAll(CurrentImports.LookupType(p), "interface{}", "any")
}

// obj returns the name of obj, prefixed with "<qualifier>." when
// CurrentImports.Lookup yields a non-empty qualifier for the object's
// package path.
func obj(obj types.Object) string {
	if qualifier := CurrentImports.Lookup(obj.Pkg().Path()); qualifier != "" {
		return qualifier + "." + obj.Name()
	}
	return obj.Name()
}

// Call returns the expression used to refer to function p from generated
// code: the function name, prefixed with "<qualifier>." when
// CurrentImports.Lookup yields a non-empty qualifier for the function's
// package path.
//
// As a side effect, the type of the function's first result (if it has one)
// is passed through ref so that it is listed in the imports.
func Call(p *types.Func) string {
	pkg := CurrentImports.Lookup(p.Pkg().Path())
	if pkg != "" {
		pkg += "."
	}

	// Guard both the type assertion and the result count: a function without
	// results has an empty result tuple, and At(0) on it would panic.
	if sig, ok := p.Type().(*types.Signature); ok && sig.Results().Len() > 0 {
		// make sure the returned type is listed in our imports.
		ref(sig.Results().At(0).Type())
	}

	return pkg + p.Name()
}

// dict builds a map from alternating key/value arguments, e.g.
// dict("a", 1, "b", 2). It returns an error when the number of arguments is
// odd or when a key is not a string.
func dict(values ...any) (map[string]any, error) {
	if len(values)%2 == 1 {
		return nil, errors.New("invalid dict call: arguments must be key-value pairs")
	}

	out := make(map[string]any, len(values)/2)
	// idx walks the value positions; the matching key sits right before it.
	for idx := 1; idx < len(values); idx += 2 {
		name, isString := values[idx-1].(string)
		if !isString {
			return nil, errors.New("dict keys must be strings")
		}
		out[name] = values[idx]
	}
	return out, nil
}

// resetModelNames replaces modelNames with a new empty map while holding
// modelNamesMu, discarding every name recorded so far.
func resetModelNames() {
	modelNamesMu.Lock()
	modelNames = map[string]string{}
	modelNamesMu.Unlock()
}

// buildGoModelNameKey joins parts with ":" to form the key under which a
// generated model name is stored.
func buildGoModelNameKey(parts []string) string {
	return strings.Join(parts, ":")
}

// goModelName returns a Go name for the entity identified by parts, unique
// among the names handed out since the last resetModelNames. The whole
// function runs under modelNamesMu.
//
// The same parts always yield the same name. Otherwise candidates are tried
// in this order until one is unused:
//  1. primaryToGoFunc applied to the first part, followed by ToGo of each
//     remaining part;
//  2. for multi-part names, the same conversion applied to a shrinking prefix
//     of parts, followed by the remaining parts with only their invalid
//     characters replaced;
//  3. candidate 1 with an incrementing number (starting at 0) appended.
func goModelName(primaryToGoFunc func(string) string, parts []string) string {
	modelNamesMu.Lock()
	defer modelNamesMu.Unlock()

	// determine if we've seen this entity before, and reuse if so
	key := buildGoModelNameKey(parts)
	if known, seen := modelNames[key]; seen {
		return known
	}

	// taken reports whether candidate was already handed out for another key.
	taken := func(candidate string) bool {
		for _, used := range modelNames {
			if used == candidate {
				return true
			}
		}
		return false
	}

	// claim records name for this entity and returns it.
	claim := func(name string) string {
		modelNames[key] = name
		return name
	}

	// goName converts the first segment with primaryToGoFunc and the rest with ToGo.
	goName := func(segments []string) string {
		if len(segments) == 0 {
			return ""
		}
		name := primaryToGoFunc(segments[0])
		for _, seg := range segments[1:] {
			name += ToGo(seg)
		}
		return name
	}

	// sanitized concatenates the segments, only replacing invalid characters.
	sanitized := func(segments []string) string {
		name := ""
		for _, seg := range segments {
			name += replaceInvalidCharacters(seg)
		}
		return name
	}

	// attempt first pass
	if name := goName(parts); !taken(name) {
		return claim(name)
	}

	// best effort "pretty" name; the loop body never runs for fewer than two parts
	for split := len(parts) - 1; split >= 1; split-- {
		if name := goName(parts[:split]) + sanitized(parts[split:]); !taken(name) {
			return claim(name)
		}
	}

	// finally, fallback to just adding an incrementing number
	base := goName(parts)
	for n := 0; ; n++ {
		if name := fmt.Sprintf("%s%d", base, n); !taken(name) {
			return claim(name)
		}
	}
}

// ToGoModelName returns the unique Go model name for parts, converting the
// first part with ToGo. See goModelName for how conflicts are resolved.
func ToGoModelName(parts ...string) string {
	name := goModelName(ToGo, parts)
	return name
}

// ToGoPrivateModelName returns the unique Go model name for parts, converting
// the first part with ToGoPrivate. See goModelName for how conflicts are
// resolved.
func ToGoPrivateModelName(parts ...string) string {
	name := goModelName(ToGoPrivate, parts)
	return name
}

// replaceInvalidCharacters returns in with every character matched by
// goNameRe replaced by an underscore.
func replaceInvalidCharacters(in string) string {
	// The replacement contains no '$', so no template expansion takes place.
	return goNameRe.ReplaceAllString(in, "_")
}

// wordWalkerFunc returns a wordWalker callback that appends each word, with
// its casing adjusted, to *nameRunes. With private set, the first word is
// lower-cased so the assembled name starts with a lower-case letter.
func wordWalkerFunc(private bool, nameRunes *[]rune) func(*wordInfo) {
	// uniformCase reports whether w is written entirely in upper case or
	// entirely in lower case.
	uniformCase := func(w string) bool {
		return w == strings.ToUpper(w) || w == strings.ToLower(w)
	}

	return func(info *wordInfo) {
		out := info.Word

		if private && info.WordOffset == 0 {
			if uniformCase(out) {
				// ID → id, CAMEL → camel
				out = strings.ToLower(out)
			} else {
				// ITicket → iTicket
				out = LcFirst(out)
			}
		} else if info.MatchCommonInitial {
			out = strings.ToUpper(out)
		} else if !info.HasCommonInitial && uniformCase(out) {
			// FOO or foo → Foo
			// FOo → FOo
			out = UcFirst(strings.ToLower(out))
		}

		*nameRunes = append(*nameRunes, []rune(out)...)
	}
}

// ToGo converts name into a Go identifier by walking its words with
// wordWalker and joining them with the casing chosen by wordWalkerFunc.
// The blank identifier "_" is returned as is.
func ToGo(name string) string {
	if name == "_" {
		return name
	}

	out := make([]rune, 0, len(name))
	wordWalker(name, wordWalkerFunc(false, &out))
	return string(out)
}

// ToGoPrivate converts name into a Go identifier whose first word is
// lower-cased, then passes the result through sanitizeKeywords. The blank
// identifier "_" is returned as is.
func ToGoPrivate(name string) string {
	if name == "_" {
		return name
	}

	out := make([]rune, 0, len(name))
	wordWalker(name, wordWalkerFunc(true, &out))
	return sanitizeKeywords(string(out))
}

// wordInfo describes one word that wordWalker passes to its callback.
//
//   - WordOffset is the zero-based position of the word among the words
//     reported so far.
//   - Word is the text of the word.
//   - MatchCommonInitial is true when the upper-cased word is one of the
//     initialisms returned by GetInitialisms.
//   - HasCommonInitial is true when MatchCommonInitial is true, or when an
//     initialism was seen as a prefix while the word was being scanned.
type wordInfo struct {
	WordOffset         int
	Word               string
	MatchCommonInitial bool
	HasCommonInitial   bool
}

// wordWalker splits str into words and calls f once for each of them, in
// order. Leading and trailing delimiters (see isDelimiter) are trimmed first.
// A word ends at the end of the input, before a delimiter, at a transition
// from a lower-case rune to a non-lower-case rune, or right after an
// initialism that is followed by a non-lower-case rune.
//
// This function is based on the following code.
// https://github.com/golang/lint/blob/06c8688daad7faa9da5a0c2f163a3d14aac986ca/lint.go#L679
func wordWalker(str string, f func(*wordInfo)) {
	runes := []rune(strings.TrimFunc(str, isDelimiter))
	w, i, wo := 0, 0, 0 // index of start of word, scan, word offset
	hasCommonInitial := false
	for i+1 <= len(runes) {
		eow := false // whether we hit the end of a word
		switch {
		case i+1 == len(runes):
			eow = true
		case isDelimiter(runes[i+1]):
			// delimiter; shift the remainder forward over any run of delimiters
			eow = true
			n := 1
			for i+n+1 < len(runes) && isDelimiter(runes[i+n+1]) {
				n++
			}

			// Keep one delimiter rune if the run of delimiters sits between two digits
			if i+n+1 < len(runes) && unicode.IsDigit(runes[i]) && unicode.IsDigit(runes[i+n+1]) {
				n--
			}

			// drop n delimiter runes by moving the tail forward and shrinking the slice
			copy(runes[i+1:], runes[i+n+1:])
			runes = runes[:len(runes)-n]
		case unicode.IsLower(runes[i]) && !unicode.IsLower(runes[i+1]):
			// lower->non-lower
			eow = true
		}
		i++

		initialisms := GetInitialisms()
		// [w,i) is a word.
		word := string(runes[w:i])
		if !eow && initialisms[word] && !unicode.IsLower(runes[i]) {
			// fall through to the word handling below
			// split IDFoo → ID, Foo
			// but URLs → URLs
		} else if !eow {
			// not at a word boundary yet: remember whether the prefix scanned so
			// far is an initialism, then keep scanning
			if initialisms[word] {
				hasCommonInitial = true
			}
			continue
		}

		matchCommonInitial := false
		upperWord := strings.ToUpper(word)
		if initialisms[upperWord] {
			// If the uppercase word (string(runes[w:i]) is "ID" or "IP"
			// AND
			// the word is the first two characters of the current word
			// AND
			// that is not the end of the word
			// AND
			// the length of the remaining string is greater than 3
			// AND
			// the rune at index 3 of the remaining string (its fourth rune) is an uppercase one
			// THEN
			// do NOT count this as an initialism; keep scanning instead.
			switch upperWord {
			case "ID", "IP":
				if remainingRunes := runes[w:]; word == string(remainingRunes[:2]) && !eow &&
					len(remainingRunes) > 3 &&
					unicode.IsUpper(remainingRunes[3]) {
					continue
				}
			}
			hasCommonInitial = true
			matchCommonInitial = true
		}

		f(&wordInfo{
			WordOffset:         wo,
			Word:               word,
			MatchCommonInitial: matchCommonInitial,
			HasCommonInitial:   hasCommonInitial,
		})
		// start the next word at the current scan position
		hasCommonInitial = false
		w = i
		wo++
	}
}

// keywords lists the names that sanitizeKeywords refuses to return verbatim:
// the Go keywords plus the blank identifier "_".
var keywords = []string{
	"break",
	"default",
	"func",
	"interface",
	"select",
	"case",
	"defer",
	"go",
	"map",
	"struct",
	"chan",
	"else",
	"goto",
	"package",
	"switch",
	"const",
	"fallthrough",
	"if",
	"range",
	"type",
	"continue",
	"for",
	"import",
	"return",
	"var",
	"_",
}

// sanitizeKeywords prevents collisions with go keywords for arguments to resolver functions.
// A name found in the keywords list gets the suffix "Arg"; any other name is
// returned unchanged.
func sanitizeKeywords(name string) string {
	for i := range keywords {
		if keywords[i] == name {
			return name + "Arg"
		}
	}
	return name
}

// rawQuote returns s as Go source for a raw (backtick-quoted) string literal.
// A raw literal cannot contain a backtick, so each backtick in s closes the
// raw literal, is concatenated as the interpreted literal "`", and reopens a
// new raw literal.
func rawQuote(s string) string {
	const tick = "`"
	const splice = tick + "+\"" + tick + "\"+" + tick

	pieces := strings.Split(s, tick)
	return tick + strings.Join(pieces, splice) + tick
}

// notNil reports whether the struct field named field of data exists and is
// not nil. data may be a struct or a pointer to a struct; anything else
// (including a nil pointer) yields false, as does an unknown field name.
//
// A field whose kind cannot hold nil (int, string, struct, ...) is reported
// as not nil.
func notNil(field string, data any) bool {
	v := reflect.ValueOf(data)

	if v.Kind() == reflect.Pointer {
		v = v.Elem()
	}
	if v.Kind() != reflect.Struct {
		return false
	}

	val := v.FieldByName(field)
	if !val.IsValid() {
		return false
	}

	// reflect.Value.IsNil panics for kinds that cannot be nil, so only call it
	// for the kinds it supports.
	switch val.Kind() {
	case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map,
		reflect.Pointer, reflect.Slice, reflect.UnsafePointer:
		return !val.IsNil()
	default:
		return true
	}
}

// StrSplit splits s around every occurrence of sep and returns all the
// resulting substrings; it behaves exactly like strings.Split.
func StrSplit(s, sep string) []string {
	// A negative count asks SplitN for every substring, which is what
	// strings.Split does.
	return strings.SplitN(s, sep, -1)
}

// Dump renders val as Go source text. It supports int, int64, float64,
// string, bool, nil, []any and map[string]any, recursing into the elements of
// the latter two. Map entries are written in sorted key order, each followed
// by a comma. Floats use the "%f" format.
//
// Dump panics for any other type.
func Dump(val any) string {
	switch v := val.(type) {
	case nil:
		return "nil"
	case bool:
		return strconv.FormatBool(v)
	case string:
		return strconv.Quote(v)
	case int:
		return strconv.Itoa(v)
	case int64:
		return strconv.FormatInt(v, 10)
	case float64:
		return fmt.Sprintf("%f", v)
	case []any:
		elems := make([]string, 0, len(v))
		for _, elem := range v {
			elems = append(elems, Dump(elem))
		}
		return "[]any{" + strings.Join(elems, ",") + "}"
	case map[string]any:
		// Sort the keys so the output does not depend on map iteration order.
		names := make([]string, 0, len(v))
		for name := range v {
			names = append(names, name)
		}
		sort.Strings(names)

		var out bytes.Buffer
		out.WriteString("map[string]any{")
		for _, name := range names {
			out.WriteString(strconv.Quote(name) + ":" + Dump(v[name]) + ",")
		}
		out.WriteString("}")
		return out.String()
	default:
		panic(fmt.Errorf("unsupported type %T", v))
	}
}

// prefixLines puts prefix in front of every line of s, where lines are
// separated by "\n". An empty s yields just the prefix, and a trailing
// newline in s is followed by a final prefix.
func prefixLines(prefix, s string) string {
	lines := strings.Split(s, "\n")
	return prefix + strings.Join(lines, "\n"+prefix)
}

// resolveName turns a template name into a file path.
//
// A name starting with "." is resolved relative to the directory of the
// source file found skip+1 frames up the call stack, i.e. the caller of
// resolveName when skip is 0; the leading "." is dropped. Any other name,
// including the empty string, is resolved relative to the directory of this
// source file.
func resolveName(name string, skip int) string {
	// HasPrefix instead of name[0] so that an empty name does not panic.
	if strings.HasPrefix(name, ".") {
		// load path relative to calling source file
		_, callerFile, _, _ := runtime.Caller(skip + 1)
		return filepath.Join(filepath.Dir(callerFile), name[1:])
	}

	// load path relative to this directory
	_, callerFile, _, _ := runtime.Caller(0)
	return filepath.Join(filepath.Dir(callerFile), name)
}

// render reads the template file at filename, parses it with the functions
// from Funcs under the file's base name, and executes it with tpldata.
//
// It returns the rendered output. An error is returned when the file cannot
// be read, when it does not parse, or when execution fails; in the last case
// the buffer holding whatever was written before the failure is returned
// alongside the error.
func render(filename string, tpldata any) (*bytes.Buffer, error) {
	t := template.New("").Funcs(Funcs())

	b, err := os.ReadFile(filename)
	if err != nil {
		return nil, err
	}

	t, err = t.New(filepath.Base(filename)).Parse(string(b))
	if err != nil {
		// Report the failure through the error result instead of panicking:
		// the function already has an error return for exactly this purpose.
		return nil, fmt.Errorf("parsing template %s: %w", filename, err)
	}

	buf := &bytes.Buffer{}
	return buf, t.Execute(buf, tpldata)
}

// write stores b in filename, creating the parent directory if needed.
//
// The content is first passed through imports.Prune. If that fails, the
// failure is reported on stderr and the content is written as given, so the
// file still lands on disk.
func write(filename string, b []byte, packages *code.Packages) error {
	if err := os.MkdirAll(filepath.Dir(filename), 0o755); err != nil {
		return fmt.Errorf("failed to create directory: %w", err)
	}

	content, pruneErr := imports.Prune(filename, b, packages)
	if pruneErr != nil {
		fmt.Fprintf(os.Stderr, "gofmt failed on %s: %s\n", filepath.Base(filename), pruneErr.Error())
		content = b
	}

	if err := os.WriteFile(filename, content, 0o644); err != nil {
		return fmt.Errorf("failed to write %s: %w", filename, err)
	}
	return nil
}

// pkgReplacer substitutes the characters '/', '.', '-' and '~' with letter
// runes. TypeIdentifier applies it to package paths.
// NOTE(review): presumably the substitutes were chosen because they are valid
// in Go identifiers — confirm.
var pkgReplacer = strings.NewReplacer(
	"/", "ᚋ",
	".", "ᚗ",
	"-", "ᚑ",
	"~", "א",
)

// TypeIdentifier builds a string identifying t. Aliases are resolved with
// code.Unalias, then pointers and slices are unwrapped one level at a time:
//
//   - each pointer contributes "ᚖ" and each slice "ᚕ";
//   - a named type ends the string with its package path (rewritten by
//     pkgReplacer), "ᚐ" and its name; the path is omitted for types that
//     belong to no package, such as the predeclared error;
//   - a basic type ends it with its name, a map with "map" and an interface
//     with "interface".
//
// It panics for any other kind of type.
func TypeIdentifier(t types.Type) string {
	res := ""
	for {
		switch it := code.Unalias(t).(type) {
		case *types.Pointer:
			res += "ᚖ"
			t = it.Elem()
		case *types.Slice:
			res += "ᚕ"
			t = it.Elem()
		case *types.Named:
			// Pkg is nil for objects in the universe scope; calling Path on it
			// would dereference a nil pointer.
			if pkg := it.Obj().Pkg(); pkg != nil {
				res += pkgReplacer.Replace(pkg.Path())
			}
			res += "ᚐ"
			res += it.Obj().Name()
			return res
		case *types.Basic:
			res += it.Name()
			return res
		case *types.Map:
			res += "map"
			return res
		case *types.Interface:
			res += "interface"
			return res
		default:
			panic(fmt.Errorf("unexpected type %T", it))
		}
	}
}

// CommonInitialisms is a set of common initialisms, keyed by their upper-case
// spelling. It is what GetInitialisms returns by default.
// Only add entries that are highly unlikely to be non-initialisms.
// For instance, "ID" is fine (Freudian code is rare), but "AND" is not.
var CommonInitialisms = map[string]bool{
	"ACL":   true,
	"API":   true,
	"ASCII": true,
	"CPU":   true,
	"CSS":   true,
	"CSV":   true,
	"DNS":   true,
	"EOF":   true,
	"GUID":  true,
	"HTML":  true,
	"HTTP":  true,
	"HTTPS": true,
	"ICMP":  true,
	"ID":    true,
	"IP":    true,
	"JSON":  true,
	"KVK":   true,
	"LHS":   true,
	"PDF":   true,
	"PGP":   true,
	"QPS":   true,
	"QR":    true,
	"RAM":   true,
	"RHS":   true,
	"RPC":   true,
	"SLA":   true,
	"SMTP":  true,
	"SQL":   true,
	"SSH":   true,
	"SVG":   true,
	"TCP":   true,
	"TLS":   true,
	"TTL":   true,
	"UDP":   true,
	"UI":    true,
	"UID":   true,
	"URI":   true,
	"URL":   true,
	"UTF8":  true,
	"UUID":  true,
	"VM":    true,
	"XML":   true,
	"XMPP":  true,
	"XSRF":  true,
	"XSS":   true,
	"AWS":   true,
	"GCP":   true,
}

// GetInitialisms returns the initialisms to capitalize in Go names. It is a
// package-level variable holding a function, so it can be reassigned to supply
// a different set; unless it is replaced, CommonInitialisms is returned.
// wordWalker calls it while splitting names into words.
var GetInitialisms = func() map[string]bool {
	return CommonInitialisms
}
